Figure 5: Cross-border model generalization

This notebook generates individual panels of Figure 5 in "Combining satellite imagery and machine learning to predict poverty".


In [210]:
from fig_utils import *
import matplotlib.pyplot as plt
import time

%matplotlib inline

Out-of-country performance

In this experiment, we compare the performance of models trained on a country's own survey data (in-country) with that of models trained on data from other countries (out-of-country).

The parameters needed to produce the plots for Panels A and B are as follows:

  • country_names: Names of survey data countries
  • country_paths: Paths of directories containing pooled survey data
  • survey: Either 'lsms' or 'dhs'
  • dimension: Number of dimensions to which image features are reduced using PCA
  • k: Number of cross validation folds
  • trials: Number of trials to average over
  • points: Number of regularization parameters to try
  • alpha_low: Log of smallest regularization parameter to try
  • alpha_high: Log of largest regularization parameter to try
  • cmap: Color scheme to use for plot, e.g., 'Blues' or 'Greens'
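
As a rough illustration of how points, alpha_low and alpha_high could together define the regularization grid, the sketch below assumes that the logs are base 10 and that the candidates are evenly spaced in log space. Both are assumptions; fig_utils may build its grid differently.

```python
import numpy as np

# Values taken from the parameter cells in this notebook
points = 30
alpha_low = -2
alpha_high = 5

# Assumption: base-10 logs, candidates evenly spaced in log space
alphas = np.logspace(alpha_low, alpha_high, points)

print('smallest: {}, largest: {}, count: {}'.format(
    alphas[0], alphas[-1], len(alphas)))
```

Under these assumptions the grid spans seven orders of magnitude, so a wider range or more points mainly increases run time.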

For 10 trials, the LSMS plot should take around 5 minutes and the DHS plot should take around 15 minutes.

Each data directory should contain the following 4 files:

  • conv_features.npy: (n, 4096) array containing image features corresponding to n clusters
  • nightlights.npy: (n,) vector containing the average nightlights value for each cluster
  • households.npy: (n,) vector containing the number of households for each cluster
  • image_counts.npy: (n,) vector containing the number of images available for each cluster

Each data directory should also contain one of the following:

  • consumptions.npy: (n,) vector containing average cluster consumption expenditures for LSMS surveys
  • assets.npy: (n,) vector containing average cluster asset index for DHS surveys
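
Before starting a long run, it can help to confirm that a data directory matches the layout described above. The helper below is a hypothetical sketch and not part of fig_utils; it only checks that the listed files exist and have the stated shapes.

```python
import os
import numpy as np


def check_data_dir(path, survey):
    """Hypothetical helper: validate file shapes and return the cluster count n."""
    if survey == 'lsms':
        target = 'consumptions.npy'
    elif survey == 'dhs':
        target = 'assets.npy'
    else:
        raise ValueError("survey must be 'lsms' or 'dhs'")

    # Image features should be an (n, 4096) array
    features = np.load(os.path.join(path, 'conv_features.npy'))
    if features.ndim != 2 or features.shape[1] != 4096:
        raise ValueError('unexpected feature shape: {}'.format(features.shape))
    n = features.shape[0]

    # Every other file should be an (n,) vector
    for name in ['nightlights.npy', 'households.npy',
                 'image_counts.npy', target]:
        arr = np.load(os.path.join(path, name))
        if arr.shape != (n,):
            raise ValueError('{}: expected shape ({},), got {}'.format(
                name, n, arr.shape))
    return n
```

For example, `check_data_dir('../data/output/LSMS/nigeria/', 'lsms')` would return the number of clusters if the directory is complete.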

Exact results may differ slightly from run to run because the data are split randomly into training and test sets.
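
The sketch below illustrates why a fixed seed makes random splits repeatable. The function split_indices is hypothetical and only stands in for whatever splitting fig_utils performs; whether evaluate_models accepts or respects a seed is not stated in this notebook.

```python
import numpy as np


def split_indices(n, k, seed):
    """Hypothetical helper: shuffle n indices with a fixed seed, cut into k folds."""
    rng = np.random.RandomState(seed)
    return np.array_split(rng.permutation(n), k)


# Two calls with the same seed give identical folds
folds_a = split_indices(100, 10, seed=0)
folds_b = split_indices(100, 10, seed=0)
same = all(np.array_equal(a, b) for a, b in zip(folds_a, folds_b))
print(same)
```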

Panel A: LSMS consumption expenditures


In [206]:
# Parameters
country_names = ['nigeria', 'tanzania', 'uganda', 'malawi', 'pooled']
country_paths = ['../data/output/LSMS/nigeria/',
                '../data/output/LSMS/tanzania/',
                '../data/output/LSMS/uganda/',
                '../data/output/LSMS/malawi/',
                '../data/output/LSMS/pooled/']
survey = 'lsms'
dimension = 100
k = 10
trials = 10
points = 30
alpha_low = -2
alpha_high = 5
cmap = 'Greens'

In [207]:
t0 = time.time()
performance_matrix = evaluate_models(country_names, country_paths, survey,
                                     dimension, k, trials, points,
                                     alpha_low, alpha_high, cmap)
t1 = time.time()
# Parenthesized single-argument prints run under both Python 2 and Python 3
print('Time elapsed: {} seconds'.format(t1-t0))
print('Corresponding values:')
print(performance_matrix)


Time elapsed: 245.791417122 seconds
Corresponding values:
[[ 0.42  0.4   0.41  0.33  0.45]
 [ 0.25  0.34  0.37  0.39  0.38]
 [ 0.34  0.38  0.46  0.37  0.48]
 [ 0.43  0.55  0.47  0.48  0.56]
 [ 0.41  0.3   0.26  0.19  0.45]]
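
One way to read the matrix printed above is to compare its diagonal with its off-diagonal entries. The sketch below assumes that entry [i, i] pairs a country with itself (in-country) and that rows and columns both follow the order of country_names; neither is stated explicitly in this notebook, so treat the summary as illustrative.

```python
import numpy as np

country_names = ['nigeria', 'tanzania', 'uganda', 'malawi', 'pooled']

# Values copied from the Panel A output in this notebook
performance_matrix = np.array([
    [0.42, 0.40, 0.41, 0.33, 0.45],
    [0.25, 0.34, 0.37, 0.39, 0.38],
    [0.34, 0.38, 0.46, 0.37, 0.48],
    [0.43, 0.55, 0.47, 0.48, 0.56],
    [0.41, 0.30, 0.26, 0.19, 0.45],
])

# Drop the trailing 'pooled' row and column to keep single countries only
n = len(country_names) - 1
countries = performance_matrix[:n, :n]

# Assumption: diagonal entries are in-country, the rest are out-of-country
in_country = np.diag(countries).mean()
out_of_country = countries[~np.eye(n, dtype=bool)].mean()

print('in-country mean: {:.3f}'.format(in_country))
print('out-of-country mean: {:.3f}'.format(out_of_country))
```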

Panel B: DHS assets


In [208]:
# Parameters
country_names = ['nigeria', 'tanzania', 'uganda', 'malawi', 'rwanda',
                 'pooled']
country_paths = ['../data/output/DHS/nigeria/',
                '../data/output/DHS/tanzania/',
                '../data/output/DHS/uganda/',
                '../data/output/DHS/malawi/',
                '../data/output/DHS/rwanda/',
                '../data/output/DHS/pooled/']
survey = 'dhs'
dimension = 100
k = 10
trials = 10
points = 30
alpha_low = -2
alpha_high = 5
cmap = 'Blues'

In [209]:
t0 = time.time()
performance_matrix = evaluate_models(country_names, country_paths, survey,
                                     dimension, k, trials, points,
                                     alpha_low, alpha_high, cmap)
t1 = time.time()
# Parenthesized single-argument prints run under both Python 2 and Python 3
print('Time elapsed: {} seconds'.format(t1-t0))
print('Corresponding values:')
print(performance_matrix)


Time elapsed: 928.99625206 seconds
Corresponding values:
[[ 0.5   0.46  0.46  0.41  0.43  0.56]
 [ 0.55  0.68  0.66  0.59  0.74  0.72]
 [ 0.4   0.42  0.43  0.55  0.41  0.48]
 [ 0.56  0.61  0.67  0.55  0.56  0.64]
 [ 0.43  0.58  0.51  0.4   0.46  0.55]
 [ 0.69  0.5   0.46  0.25  0.38  0.66]]
